import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel

class ModelUtils:

    @classmethod
    def load_model(cls, model_name_or_path, load_in_4bit=False, adapter_name_or_path=None):
        # Optionally quantize the model to 4-bit (NF4) for inference
        if load_in_4bit:
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                llm_int8_threshold=6.0,
                llm_int8_has_fp16_weight=False,
            )
        else:
            quantization_config = None

        if "full" in adapter_name_or_path:
            model = AutoModelForCausalLM.from_pretrained(
                adapter_name_or_path,
                trust_remote_code=True,
                low_cpu_mem_usage=True,
                torch_dtype=torch.float16,
                device_map='auto',
                quantization_config=quantization_config
            )
        else:
            model = AutoModelForCausalLM.from_pretrained(
                model_name_or_path,
                trust_remote_code=True,
                low_cpu_mem_usage=True,
                torch_dtype=torch.float16,
                device_map='auto',
                quantization_config=quantization_config
            )
            # If an adapter checkpoint is given, attach the PEFT (e.g. LoRA)
            # weights on top of the base model.
            if adapter_name_or_path is not None:
                model = PeftModel.from_pretrained(model, adapter_name_or_path, torch_dtype=torch.float16)
        print("successefully loaded model and adapter")
        return model
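

# Usage sketch: the base model name and adapter path below are hypothetical
# placeholders, not part of this repo; substitute your own checkpoints.
if __name__ == "__main__":
    model = ModelUtils.load_model(
        "meta-llama/Llama-2-7b-hf",          # hypothetical base model
        load_in_4bit=True,                    # 4-bit NF4 quantized inference
        adapter_name_or_path="output/lora",   # hypothetical adapter directory
    )
    model.eval()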